import java.util.Scanner;

public class Main {
    public static void main(String[] args) {
        Scanner scanner = new Scanner(System.in);
        Solution solution = new Solution();
        int n = scanner.nextInt();
        int m = scanner.nextInt();
        solution.differenceOfSums(n,m);
    }
}

class Solution {
    /**
     * Returns {@code num1 - num2}, where {@code num1} is the sum of the integers in
     * {@code [1, n]} that are not divisible by {@code m} and {@code num2} is the sum of
     * those that are.
     *
     * <p>Runs in O(1). The multiples of {@code m} in {@code [1, n]} are
     * {@code |m|, 2|m|, ..., k|m|} with {@code k = n / |m|}, so
     * {@code num2 = |m| * k * (k + 1) / 2}. Since {@code num1 = total - num2}, the answer is
     * {@code total - 2 * num2}, where {@code total = n * (n + 1) / 2}. All intermediate
     * arithmetic is done in {@code long} so that {@code n * (n + 1)} and {@code 2 * num2}
     * cannot overflow.
     *
     * @param n upper bound of the range {@code [1, n]}; for {@code n < 1} the range holds no
     *     multiples of {@code m}, so only the {@code (1 + n) * n / 2} term contributes
     * @param m the divisor; a negative value has the same multiples as {@code |m|}
     * @return {@code num1 - num2}
     * @throws ArithmeticException if {@code m == 0} while {@code n >= 1}, or if the result
     *     does not fit in an {@code int}
     */
    public int differenceOfSums(int n, int m) {
        long total = (1L + n) * n / 2;
        long divisibleSum = 0;
        if (n >= 1) {
            // Widen before abs so that m == Integer.MIN_VALUE yields a positive step.
            long step = Math.abs((long) m);
            long k = n / step; // division by zero throws ArithmeticException when m == 0
            // k * (k + 1) is always even, so the division is exact.
            divisibleSum = step * k * (k + 1) / 2;
        }
        return Math.toIntExact(total - 2 * divisibleSum);
    }
}